{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example: LLMs Learn to Navigate Mazes from Experience (BabyAI Benchmark)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import random\n",
    "from typing import List, Optional, Tuple\n",
    "from uuid import UUID\n",
    "\n",
    "import altair as alt\n",
    "import pandas as pd\n",
    "import yaml\n",
    "from balrog.environments import make_env\n",
    "from omegaconf import OmegaConf\n",
    "from tensorzero import AsyncTensorZeroGateway\n",
    "from tensorzero.util import uuid7\n",
    "from tqdm import trange"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load config for BALROG environments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the YAML file with an explicit encoding so the result does not depend\n",
    "# on the platform's default locale encoding.\n",
    "with open(\"config.yml\", encoding=\"utf-8\") as f:\n",
    "    config_dict = yaml.safe_load(f)\n",
    "# Wrap the plain dict so nested settings can be read with attribute access\n",
    "config = OmegaConf.create(config_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upper bound on the number of inference requests in flight at once.\n",
    "# Lower this value if you're getting rate-limited by OpenAI.\n",
    "MAX_CONCURRENT_T0_REQUESTS = 50\n",
    "semaphore = asyncio.Semaphore(value=MAX_CONCURRENT_T0_REQUESTS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper Functions\n",
    "\n",
    "The `run_episode` function executes a single episode of the agent for a BabyAI task (maze navigation game)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def run_episode(\n",
    "    client: AsyncTensorZeroGateway,\n",
    "    variant_name: str,\n",
    "    env_name: str,\n",
    "    task_name: str,\n",
    "    episode_idx: int,\n",
    "    config: OmegaConf,\n",
    "    semaphore: asyncio.Semaphore,\n",
    "    history_length: int = 2,\n",
    "    seed: int = 0,\n",
    "    test: bool = False,\n",
    ") -> dict:\n",
    "    \"\"\"Run a single episode of the `act` function on one task and return its log.\n",
    "\n",
    "    Args:\n",
    "        client: TensorZero gateway client used for inference and feedback.\n",
    "        variant_name: Variant of the `act` function to use. Past steps are only\n",
    "            recorded in the prompt history when the name contains \"history\".\n",
    "        env_name: Environment name passed to `make_env`.\n",
    "        task_name: Task name passed to `make_env`.\n",
    "        episode_idx: Index of the episode within the batch.\n",
    "        config: Configuration passed to `make_env`.\n",
    "        semaphore: Bounds the number of concurrent inference requests.\n",
    "        history_length: Maximum number of past steps included in the prompt.\n",
    "            A non-positive value means no history is sent.\n",
    "        seed: Offset added to `episode_idx` when seeding the environment.\n",
    "        test: Forwarded as `dryrun` to the feedback calls.\n",
    "\n",
    "    Returns:\n",
    "        A dict with the variant, task, token counts, episode return, number of\n",
    "        steps, failed candidates, environment stats, seed index and episode ID.\n",
    "    \"\"\"\n",
    "    fallback_actions = (\n",
    "        \"turn left\",\n",
    "        \"turn right\",\n",
    "        \"go forward\",\n",
    "        \"pick up\",\n",
    "        \"drop\",\n",
    "        \"toggle\",\n",
    "    )\n",
    "    episode_log = {\n",
    "        \"variant\": variant_name,\n",
    "        \"task\": task_name,\n",
    "        \"input_tokens\": 0,\n",
    "        \"output_tokens\": 0,\n",
    "    }\n",
    "    use_history = \"history\" in variant_name\n",
    "    episode_id = uuid7()\n",
    "    env = make_env(env_name, task_name, config)\n",
    "    obs, _ = env.reset(seed=episode_idx + seed)\n",
    "    mission = obs[\"mission\"]\n",
    "    episode_return = 0\n",
    "    history = []\n",
    "    # Counted explicitly so the value is defined even if the loop never runs\n",
    "    num_steps = 0\n",
    "    for _step in range(env.max_steps):\n",
    "        # Read the observation before the `try` so `state` is always bound when\n",
    "        # the history is updated, even if the inference call fails.\n",
    "        state = obs[\"text\"][\"long_term_context\"]\n",
    "        # `history[-0:]` is the whole list, so a non-positive length must be\n",
    "        # handled explicitly to mean \"no history\".\n",
    "        recent_history = history[-history_length:] if history_length > 0 else []\n",
    "        # Generate action\n",
    "        try:\n",
    "            # Only the gateway request itself is limited by the semaphore\n",
    "            async with semaphore:\n",
    "                response = await client.inference(\n",
    "                    function_name=\"act\",\n",
    "                    variant_name=variant_name,\n",
    "                    input={\n",
    "                        \"system\": {\n",
    "                            \"mission\": mission,\n",
    "                        },\n",
    "                        \"messages\": [\n",
    "                            {\n",
    "                                \"role\": \"user\",\n",
    "                                \"content\": [\n",
    "                                    {\n",
    "                                        \"type\": \"text\",\n",
    "                                        \"arguments\": {\n",
    "                                            \"observation\": state,\n",
    "                                            \"history\": \"\\n\".join(recent_history),\n",
    "                                        },\n",
    "                                    }\n",
    "                                ],\n",
    "                            }\n",
    "                        ],\n",
    "                    },\n",
    "                    episode_id=episode_id,\n",
    "                    cache_options={\"enabled\": \"on\"},\n",
    "                )\n",
    "            episode_log[\"input_tokens\"] += response.usage.input_tokens\n",
    "            episode_log[\"output_tokens\"] += response.usage.output_tokens\n",
    "            action = response.output.parsed[\"action\"]\n",
    "            # Check if action is valid and set to default if not\n",
    "            action = env.check_action_validity(action)\n",
    "        except Exception as e:\n",
    "            # Best effort: keep the episode going with a random move\n",
    "            print(f\"Error occurred: {type(e).__name__}: {e}\")\n",
    "            print(\"Choosing a random legal move as fallback.\")\n",
    "            action = random.choice(fallback_actions)\n",
    "        # Update history\n",
    "        if use_history:\n",
    "            history.append(f\"Observation:{state}\\n\\nYour Response:\\n{action}\\n\")\n",
    "        # Interact with environment\n",
    "        obs, reward, terminated, truncated, info = env.step(action)\n",
    "        num_steps += 1\n",
    "        # Update episode return\n",
    "        episode_return += reward\n",
    "        # Stop as soon as the environment reports the episode is over\n",
    "        if terminated or truncated:\n",
    "            break\n",
    "    # See if episode is successful\n",
    "    progression = env.get_stats()[\"progression\"]\n",
    "    # Log feedback\n",
    "    await client.feedback(\n",
    "        metric_name=\"episode_return\",\n",
    "        episode_id=episode_id,\n",
    "        value=episode_return,\n",
    "        dryrun=test,\n",
    "    )\n",
    "    await client.feedback(\n",
    "        metric_name=\"progression\",\n",
    "        episode_id=episode_id,\n",
    "        value=progression,\n",
    "        dryrun=test,\n",
    "    )\n",
    "    episode_log[\"episode_return\"] = episode_return\n",
    "    episode_log[\"num_steps\"] = num_steps\n",
    "    episode_log[\"failed_candidates\"] = env.failed_candidates\n",
    "    episode_log.update(env.get_stats())\n",
    "    # NOTE: this records the episode index, not the `episode_idx + seed` value\n",
    "    # that was passed to `env.reset`.\n",
    "    episode_log[\"seed\"] = episode_idx\n",
    "    episode_log[\"episode_id\"] = episode_id\n",
    "    return episode_log"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define a function to run multiple episodes of the agent for a BabyAI task in parallel. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def run_episodes(\n",
    "    client: AsyncTensorZeroGateway,\n",
    "    variant_name: str,\n",
    "    env_name: str,\n",
    "    task_name: str,\n",
    "    num_episodes: int,\n",
    "    config: OmegaConf,\n",
    "    semaphore: asyncio.Semaphore,\n",
    "    disable_progress_bar: bool = False,\n",
    "    history_length: int = 2,\n",
    "    seed: int = 0,\n",
    "    test: bool = False,\n",
    ") -> List[dict]:\n",
    "    \"\"\"Run `num_episodes` episodes of one task concurrently and collect their logs.\n",
    "\n",
    "    Args:\n",
    "        client: TensorZero gateway client shared by all episodes.\n",
    "        variant_name: Variant name forwarded to `run_episode`.\n",
    "        env_name: Environment name forwarded to `run_episode`.\n",
    "        task_name: Task name forwarded to `run_episode`.\n",
    "        num_episodes: Number of episodes; episode indices are `0..num_episodes - 1`.\n",
    "        config: Configuration forwarded to `run_episode`.\n",
    "        semaphore: Semaphore forwarded to `run_episode`.\n",
    "        disable_progress_bar: If True, the progress bar is disabled.\n",
    "        history_length: Forwarded to `run_episode`.\n",
    "        seed: Forwarded to `run_episode`.\n",
    "        test: Forwarded to `run_episode`.\n",
    "\n",
    "    Returns:\n",
    "        The episode logs returned by `run_episode`, in the order the episodes finish.\n",
    "    \"\"\"\n",
    "    progress_bar = trange(\n",
    "        num_episodes,\n",
    "        desc=f\"{env_name} {task_name} {variant_name}\",\n",
    "        disable=disable_progress_bar,\n",
    "    )\n",
    "\n",
    "    tasks = [\n",
    "        asyncio.create_task(\n",
    "            run_episode(\n",
    "                client=client,\n",
    "                variant_name=variant_name,\n",
    "                env_name=env_name,\n",
    "                task_name=task_name,\n",
    "                episode_idx=episode_idx,\n",
    "                config=config,\n",
    "                semaphore=semaphore,\n",
    "                history_length=history_length,\n",
    "                seed=seed,\n",
    "                test=test,\n",
    "            )\n",
    "        )\n",
    "        for episode_idx in range(num_episodes)\n",
    "    ]\n",
    "\n",
    "    num_successes = 0\n",
    "    episode_logs = []\n",
    "    try:\n",
    "        for finished in asyncio.as_completed(tasks):\n",
    "            episode_log = await finished\n",
    "            if episode_log[\"progression\"] == 1.0:\n",
    "                num_successes += 1\n",
    "            episode_logs.append(episode_log)\n",
    "            progress_bar.update(1)\n",
    "            progress_bar.set_postfix(\n",
    "                {\"Success\": f\"{num_successes}/{len(episode_logs)}\"},\n",
    "                refresh=True,\n",
    "            )\n",
    "    finally:\n",
    "        # If an episode raised (or the cell was interrupted), stop the episodes\n",
    "        # that are still running; on the normal path every task is already done.\n",
    "        for pending in tasks:\n",
    "            if not pending.done():\n",
    "                pending.cancel()\n",
    "        # Always release the progress bar, even on failure\n",
    "        progress_bar.close()\n",
    "    return episode_logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tasks to evaluate, as listed in the loaded config\n",
    "task_names = config.tasks.babyai_tasks\n",
    "# Number of evaluation episodes per task and variant\n",
    "num_episodes = 20\n",
    "# Seed offset used for the evaluation episodes\n",
    "seed = 200"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baseline\n",
    "\n",
    "The `baseline` variant uses a simple system prompt that guides the LLM to navigate the maze.\n",
    "\n",
    "You can find the prompts in `config/functions/act/baseline`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_baseline = []\n",
    "\n",
    "for task_name in task_names:\n",
    "    # A gateway client is opened per task and closed once the task's episodes finish\n",
    "    gateway = await AsyncTensorZeroGateway.build_http(gateway_url=\"http://localhost:3000\", timeout=180.0)\n",
    "    async with gateway as client:\n",
    "        results_task = await run_episodes(\n",
    "            client=client,\n",
    "            variant_name=\"baseline\",\n",
    "            env_name=\"babyai\",\n",
    "            task_name=task_name,\n",
    "            num_episodes=num_episodes,\n",
    "            config=config,\n",
    "            semaphore=semaphore,\n",
    "            disable_progress_bar=False,\n",
    "            seed=seed,\n",
    "            test=True,\n",
    "        )\n",
    "        # Accumulate the episode logs of every task in one flat list\n",
    "        results_baseline += results_task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reasoning\n",
    "\n",
    "The `reasoning` variant uses a system prompt that guides the LLM to reason about the best course of action.\n",
    "\n",
    "You can find the prompts in `config/functions/act/reasoning`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_reasoning = []\n",
    "\n",
    "for task_name in task_names:\n",
    "    # A gateway client is opened per task and closed once the task's episodes finish\n",
    "    gateway = await AsyncTensorZeroGateway.build_http(gateway_url=\"http://localhost:3000\", timeout=180.0)\n",
    "    async with gateway as client:\n",
    "        results_task = await run_episodes(\n",
    "            client=client,\n",
    "            variant_name=\"reasoning\",\n",
    "            env_name=\"babyai\",\n",
    "            task_name=task_name,\n",
    "            num_episodes=num_episodes,\n",
    "            config=config,\n",
    "            semaphore=semaphore,\n",
    "            disable_progress_bar=False,\n",
    "            seed=seed,\n",
    "            test=True,\n",
    "        )\n",
    "        # Accumulate the episode logs of every task in one flat list\n",
    "        results_reasoning += results_task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## History\n",
    "\n",
    "The `history` variant uses the previous observations and actions to guide the LLM to navigate the maze.\n",
    "We add the most recent observations and actions (up to `history_length`, which is set to 8 below) to the field `history` in the examples below.\n",
    "\n",
    "You can find the prompts in `config/functions/act/history`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maximum number of past steps included in each prompt\n",
    "history_length = 8\n",
    "\n",
    "results_history = []\n",
    "\n",
    "for task_name in task_names:\n",
    "    # A gateway client is opened per task and closed once the task's episodes finish\n",
    "    gateway = await AsyncTensorZeroGateway.build_http(gateway_url=\"http://localhost:3000\", timeout=180.0)\n",
    "    async with gateway as client:\n",
    "        results_task = await run_episodes(\n",
    "            client=client,\n",
    "            variant_name=\"history\",\n",
    "            env_name=\"babyai\",\n",
    "            task_name=task_name,\n",
    "            num_episodes=num_episodes,\n",
    "            config=config,\n",
    "            semaphore=semaphore,\n",
    "            disable_progress_bar=False,\n",
    "            history_length=history_length,\n",
    "            seed=seed,\n",
    "            test=True,\n",
    "        )\n",
    "        # Accumulate the episode logs of every task in one flat list\n",
    "        results_history += results_task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## History and Reasoning\n",
    "\n",
    "The `history_and_reasoning` variant combines the reasoning variant and the history variant.\n",
    "\n",
    "You can find the prompts in `config/functions/act/history_and_reasoning`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_history_and_reasoning = []\n",
    "\n",
    "for task_name in task_names:\n",
    "    # A gateway client is opened per task and closed once the task's episodes finish\n",
    "    gateway = await AsyncTensorZeroGateway.build_http(gateway_url=\"http://localhost:3000\", timeout=180.0)\n",
    "    async with gateway as client:\n",
    "        results_task = await run_episodes(\n",
    "            client=client,\n",
    "            variant_name=\"history_and_reasoning\",\n",
    "            env_name=\"babyai\",\n",
    "            task_name=task_name,\n",
    "            num_episodes=num_episodes,\n",
    "            config=config,\n",
    "            semaphore=semaphore,\n",
    "            disable_progress_bar=False,\n",
    "            history_length=history_length,\n",
    "            seed=seed,\n",
    "            test=True,\n",
    "        )\n",
    "        # Accumulate the episode logs of every task in one flat list\n",
    "        results_history_and_reasoning += results_task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per evaluated episode, across the four prompt variants\n",
    "all_results = [\n",
    "    *results_baseline,\n",
    "    *results_reasoning,\n",
    "    *results_history,\n",
    "    *results_history_and_reasoning,\n",
    "]\n",
    "df = pd.DataFrame(all_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Success Rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-variant mean and standard error of the mean of the success metric\n",
    "summary = (\n",
    "    df.groupby(\"variant\")[\"progression\"]\n",
    "    .agg([\"mean\", \"sem\"])\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "axis_title = \"Value ± 1 SEM\"\n",
    "base = alt.Chart(summary)\n",
    "\n",
    "# Bar layer: one bar per variant at its mean\n",
    "bars = base.encode(\n",
    "    y=alt.Y(\"variant:N\", title=\"Variant\"),\n",
    "    x=alt.X(\"mean:Q\", title=axis_title, scale=alt.Scale(zero=False)),\n",
    ").mark_bar(color=\"#1f77b4\")\n",
    "\n",
    "# Error-bar layer: from mean - SEM to mean + SEM\n",
    "error_bars = (\n",
    "    base.mark_errorbar(color=\"black\")\n",
    "    .encode(y=\"variant:N\", x=alt.X(\"low:Q\", title=axis_title), x2=\"high:Q\")\n",
    "    .transform_calculate(low=\"datum.mean - datum.sem\", high=\"datum.mean + datum.sem\")\n",
    ")\n",
    "\n",
    "# Overlay the two layers\n",
    "chart = alt.layer(bars, error_bars).properties(title=\"Task Success Rate\")\n",
    "\n",
    "chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Episode Return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-variant mean and standard error of the mean of the episode return\n",
    "summary = (\n",
    "    df.groupby(\"variant\")[\"episode_return\"]\n",
    "    .agg([\"mean\", \"sem\"])\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "axis_title = \"Value ± 1 SEM\"\n",
    "base = alt.Chart(summary)\n",
    "\n",
    "# Bar layer: one bar per variant at its mean\n",
    "bars = base.encode(\n",
    "    y=alt.Y(\"variant:N\", title=\"Variant\"),\n",
    "    x=alt.X(\"mean:Q\", title=axis_title, scale=alt.Scale(zero=False)),\n",
    ").mark_bar(color=\"#1f77b4\")\n",
    "\n",
    "# Error-bar layer: from mean - SEM to mean + SEM\n",
    "error_bars = (\n",
    "    base.mark_errorbar(color=\"black\")\n",
    "    .encode(y=\"variant:N\", x=alt.X(\"low:Q\", title=axis_title), x2=\"high:Q\")\n",
    "    .transform_calculate(low=\"datum.mean - datum.sem\", high=\"datum.mean + datum.sem\")\n",
    ")\n",
    "\n",
    "# Overlay the two layers\n",
    "chart = alt.layer(bars, error_bars).properties(title=\"Episode Return\")\n",
    "\n",
    "chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Episode Length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-variant mean and standard error of the mean of the number of steps\n",
    "summary = (\n",
    "    df.groupby(\"variant\")[\"num_steps\"]\n",
    "    .agg([\"mean\", \"sem\"])\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "axis_title = \"Value ± 1 SEM\"\n",
    "base = alt.Chart(summary)\n",
    "\n",
    "# Bar layer: one bar per variant at its mean\n",
    "bars = base.encode(\n",
    "    y=alt.Y(\"variant:N\", title=\"Variant\"),\n",
    "    x=alt.X(\"mean:Q\", title=axis_title, scale=alt.Scale(zero=False)),\n",
    ").mark_bar(color=\"#1f77b4\")\n",
    "\n",
    "# Error-bar layer: from mean - SEM to mean + SEM\n",
    "error_bars = (\n",
    "    base.mark_errorbar(color=\"black\")\n",
    "    .encode(y=\"variant:N\", x=alt.X(\"low:Q\", title=axis_title), x2=\"high:Q\")\n",
    "    .transform_calculate(low=\"datum.mean - datum.sem\", high=\"datum.mean + datum.sem\")\n",
    ")\n",
    "\n",
    "# Overlay the two layers\n",
    "chart = alt.layer(bars, error_bars).properties(title=\"Episode Length\")\n",
    "\n",
    "chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Episode Generated Token Count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-variant mean and standard error of the mean of the generated token count\n",
    "summary = (\n",
    "    df.groupby(\"variant\")[\"output_tokens\"]\n",
    "    .agg([\"mean\", \"sem\"])\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "axis_title = \"Value ± 1 SEM\"\n",
    "base = alt.Chart(summary)\n",
    "\n",
    "# Bar layer: one bar per variant at its mean\n",
    "bars = base.encode(\n",
    "    y=alt.Y(\"variant:N\", title=\"Variant\"),\n",
    "    x=alt.X(\"mean:Q\", title=axis_title, scale=alt.Scale(zero=False)),\n",
    ").mark_bar(color=\"#1f77b4\")\n",
    "\n",
    "# Error-bar layer: from mean - SEM to mean + SEM\n",
    "error_bars = (\n",
    "    base.mark_errorbar(color=\"black\")\n",
    "    .encode(y=\"variant:N\", x=alt.X(\"low:Q\", title=axis_title), x2=\"high:Q\")\n",
    "    .transform_calculate(low=\"datum.mean - datum.sem\", high=\"datum.mean + datum.sem\")\n",
    ")\n",
    "\n",
    "# Overlay the two layers\n",
    "chart = alt.layer(bars, error_bars).properties(title=\"Episode Generated Token Count\")\n",
    "\n",
    "chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Episode Input Token Count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-variant mean and standard error of the mean of the input token count\n",
    "summary = (\n",
    "    df.groupby(\"variant\")[\"input_tokens\"]\n",
    "    .agg([\"mean\", \"sem\"])\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "axis_title = \"Value ± 1 SEM\"\n",
    "base = alt.Chart(summary)\n",
    "\n",
    "# Bar layer: one bar per variant at its mean\n",
    "bars = base.encode(\n",
    "    y=alt.Y(\"variant:N\", title=\"Variant\"),\n",
    "    x=alt.X(\"mean:Q\", title=axis_title, scale=alt.Scale(zero=False)),\n",
    ").mark_bar(color=\"#1f77b4\")\n",
    "\n",
    "# Error-bar layer: from mean - SEM to mean + SEM\n",
    "error_bars = (\n",
    "    base.mark_errorbar(color=\"black\")\n",
    "    .encode(y=\"variant:N\", x=alt.X(\"low:Q\", title=axis_title), x2=\"high:Q\")\n",
    "    .transform_calculate(low=\"datum.mean - datum.sem\", high=\"datum.mean + datum.sem\")\n",
    ")\n",
    "\n",
    "# Overlay the two layers\n",
    "chart = alt.layer(bars, error_bars).properties(title=\"Episode Input Token Count\")\n",
    "\n",
    "chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Improving Performance with Supervised Fine-tuning (SFT)\n",
    "\n",
    "The results above show that the `history_and_reasoning` variant yields the best success rate.\n",
    "Here we describe how to improve the performance of the `history_and_reasoning` variant by fine-tuning it on a separate set of random episodes.\n",
    "\n",
    "First we run a large set of episodes for each task using the `history_and_reasoning` variant to generate data for fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed offset and episode count for the data-generation run; these differ from\n",
    "# the evaluation settings (`seed`, `num_episodes`) used above.\n",
    "seed_ft = 0\n",
    "num_episodes_ft = 200\n",
    "\n",
    "for task_name in task_names:\n",
    "    # A gateway client is opened per task and closed once the task's episodes finish\n",
    "    gateway = await AsyncTensorZeroGateway.build_http(gateway_url=\"http://localhost:3000\", timeout=180.0)\n",
    "    async with gateway as client:\n",
    "        # `test=False` so the feedback calls are not sent as dry runs\n",
    "        results_task = await run_episodes(\n",
    "            client=client,\n",
    "            variant_name=\"history_and_reasoning\",\n",
    "            env_name=\"babyai\",\n",
    "            task_name=task_name,\n",
    "            num_episodes=num_episodes_ft,\n",
    "            config=config,\n",
    "            semaphore=semaphore,\n",
    "            disable_progress_bar=False,\n",
    "            history_length=history_length,\n",
    "            seed=seed_ft,\n",
    "            test=False,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "We provide two options for fine-tuning a model: using a notebook or using the TensorZero UI.\n",
    "You can fine-tune on episodes that successfully completed the task, or episodes that achieved a sufficiently high return (e.g. 0.7).\n",
    "\n",
    "See the `README.md` file for more details.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluating the Fine-tuned Variant\n",
    "\n",
    "After fine-tuning, create a `history_and_reasoning_sft` variant and run the following code to evaluate it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_history_and_reasoning_ft = []\n",
    "\n",
    "for task_name in task_names:\n",
    "    # A gateway client is opened per task and closed once the task's episodes finish\n",
    "    gateway = await AsyncTensorZeroGateway.build_http(gateway_url=\"http://localhost:3000\", timeout=180.0)\n",
    "    async with gateway as client:\n",
    "        results_task = await run_episodes(\n",
    "            client=client,\n",
    "            variant_name=\"history_and_reasoning_sft\",\n",
    "            env_name=\"babyai\",\n",
    "            task_name=task_name,\n",
    "            num_episodes=num_episodes,\n",
    "            config=config,\n",
    "            semaphore=semaphore,\n",
    "            disable_progress_bar=False,\n",
    "            history_length=history_length,\n",
    "            seed=seed,\n",
    "            test=True,\n",
    "        )\n",
    "        # Accumulate the episode logs of every task in one flat list\n",
    "        results_history_and_reasoning_ft += results_task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results\n",
    "\n",
    "We combine the results of the fine-tuned model with the results of the `history_and_reasoning` variant."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_ft = pd.DataFrame(results_history_and_reasoning_ft)\n",
    "\n",
    "# Rebuild the combined frame from the per-variant result lists instead of\n",
    "# appending to `df` itself, so re-running this cell does not duplicate the\n",
    "# fine-tuned rows.\n",
    "df_prompt_variants = pd.DataFrame(\n",
    "    results_baseline + results_reasoning + results_history + results_history_and_reasoning\n",
    ")\n",
    "df = pd.concat([df_prompt_variants, df_ft], ignore_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see below that the fine-tuned model performs better than the `history_and_reasoning` variant!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Success Rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-variant mean and standard error of the mean of the success metric\n",
    "summary = (\n",
    "    df.groupby(\"variant\")[\"progression\"]\n",
    "    .agg([\"mean\", \"sem\"])\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "axis_title = \"Value ± 1 SEM\"\n",
    "base = alt.Chart(summary)\n",
    "\n",
    "# Bar layer: one bar per variant at its mean\n",
    "bars = base.encode(\n",
    "    y=alt.Y(\"variant:N\", title=\"Variant\"),\n",
    "    x=alt.X(\"mean:Q\", title=axis_title, scale=alt.Scale(zero=False)),\n",
    ").mark_bar(color=\"#1f77b4\")\n",
    "\n",
    "# Error-bar layer: from mean - SEM to mean + SEM\n",
    "error_bars = (\n",
    "    base.mark_errorbar(color=\"black\")\n",
    "    .encode(y=\"variant:N\", x=alt.X(\"low:Q\", title=axis_title), x2=\"high:Q\")\n",
    "    .transform_calculate(low=\"datum.mean - datum.sem\", high=\"datum.mean + datum.sem\")\n",
    ")\n",
    "\n",
    "# Overlay the two layers\n",
    "chart = alt.layer(bars, error_bars).properties(title=\"Task Success Rate\")\n",
    "\n",
    "chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Episode Return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-variant mean and standard error of the mean of the episode return\n",
    "summary = (\n",
    "    df.groupby(\"variant\")[\"episode_return\"]\n",
    "    .agg([\"mean\", \"sem\"])\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "axis_title = \"Value ± 1 SEM\"\n",
    "base = alt.Chart(summary)\n",
    "\n",
    "# Bar layer: one bar per variant at its mean\n",
    "bars = base.encode(\n",
    "    y=alt.Y(\"variant:N\", title=\"Variant\"),\n",
    "    x=alt.X(\"mean:Q\", title=axis_title, scale=alt.Scale(zero=False)),\n",
    ").mark_bar(color=\"#1f77b4\")\n",
    "\n",
    "# Error-bar layer: from mean - SEM to mean + SEM\n",
    "error_bars = (\n",
    "    base.mark_errorbar(color=\"black\")\n",
    "    .encode(y=\"variant:N\", x=alt.X(\"low:Q\", title=axis_title), x2=\"high:Q\")\n",
    "    .transform_calculate(low=\"datum.mean - datum.sem\", high=\"datum.mean + datum.sem\")\n",
    ")\n",
    "\n",
    "# Overlay the two layers\n",
    "chart = alt.layer(bars, error_bars).properties(title=\"Episode Return\")\n",
    "\n",
    "chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Episode Length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-variant mean and standard error of the mean of the number of steps\n",
    "summary = (\n",
    "    df.groupby(\"variant\")[\"num_steps\"]\n",
    "    .agg([\"mean\", \"sem\"])\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "axis_title = \"Value ± 1 SEM\"\n",
    "base = alt.Chart(summary)\n",
    "\n",
    "# Bar layer: one bar per variant at its mean\n",
    "bars = base.encode(\n",
    "    y=alt.Y(\"variant:N\", title=\"Variant\"),\n",
    "    x=alt.X(\"mean:Q\", title=axis_title, scale=alt.Scale(zero=False)),\n",
    ").mark_bar(color=\"#1f77b4\")\n",
    "\n",
    "# Error-bar layer: from mean - SEM to mean + SEM\n",
    "error_bars = (\n",
    "    base.mark_errorbar(color=\"black\")\n",
    "    .encode(y=\"variant:N\", x=alt.X(\"low:Q\", title=axis_title), x2=\"high:Q\")\n",
    "    .transform_calculate(low=\"datum.mean - datum.sem\", high=\"datum.mean + datum.sem\")\n",
    ")\n",
    "\n",
    "# Overlay the two layers\n",
    "chart = alt.layer(bars, error_bars).properties(title=\"Episode Length\")\n",
    "\n",
    "chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Episode Generated Token Count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-variant mean and standard error of the mean of the generated token count\n",
    "summary = (\n",
    "    df.groupby(\"variant\")[\"output_tokens\"]\n",
    "    .agg([\"mean\", \"sem\"])\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "axis_title = \"Value ± 1 SEM\"\n",
    "base = alt.Chart(summary)\n",
    "\n",
    "# Bar layer: one bar per variant at its mean\n",
    "bars = base.encode(\n",
    "    y=alt.Y(\"variant:N\", title=\"Variant\"),\n",
    "    x=alt.X(\"mean:Q\", title=axis_title, scale=alt.Scale(zero=False)),\n",
    ").mark_bar(color=\"#1f77b4\")\n",
    "\n",
    "# Error-bar layer: from mean - SEM to mean + SEM\n",
    "error_bars = (\n",
    "    base.mark_errorbar(color=\"black\")\n",
    "    .encode(y=\"variant:N\", x=alt.X(\"low:Q\", title=axis_title), x2=\"high:Q\")\n",
    "    .transform_calculate(low=\"datum.mean - datum.sem\", high=\"datum.mean + datum.sem\")\n",
    ")\n",
    "\n",
    "# Overlay the two layers\n",
    "chart = alt.layer(bars, error_bars).properties(title=\"Episode Generated Token Count\")\n",
    "\n",
    "chart"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Episode Input Token Count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-variant mean and standard error of the mean of the input token count\n",
    "summary = (\n",
    "    df.groupby(\"variant\")[\"input_tokens\"]\n",
    "    .agg([\"mean\", \"sem\"])\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "axis_title = \"Value ± 1 SEM\"\n",
    "base = alt.Chart(summary)\n",
    "\n",
    "# Bar layer: one bar per variant at its mean\n",
    "bars = base.encode(\n",
    "    y=alt.Y(\"variant:N\", title=\"Variant\"),\n",
    "    x=alt.X(\"mean:Q\", title=axis_title, scale=alt.Scale(zero=False)),\n",
    ").mark_bar(color=\"#1f77b4\")\n",
    "\n",
    "# Error-bar layer: from mean - SEM to mean + SEM\n",
    "error_bars = (\n",
    "    base.mark_errorbar(color=\"black\")\n",
    "    .encode(y=\"variant:N\", x=alt.X(\"low:Q\", title=axis_title), x2=\"high:Q\")\n",
    "    .transform_calculate(low=\"datum.mean - datum.sem\", high=\"datum.mean + datum.sem\")\n",
    ")\n",
    "\n",
    "# Overlay the two layers\n",
    "chart = alt.layer(bars, error_bars).properties(title=\"Episode Input Token Count\")\n",
    "\n",
    "chart"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
